{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:11:37.959402Z",
     "iopub.status.busy": "2022-06-14T17:11:37.958831Z",
     "iopub.status.idle": "2022-06-14T17:11:39.305371Z",
     "shell.execute_reply": "2022-06-14T17:11:39.304839Z",
     "shell.execute_reply.started": "2022-06-14T17:11:37.959353Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "%pylab inline\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:12:22.662154Z",
     "iopub.status.busy": "2022-06-14T17:12:22.661569Z",
     "iopub.status.idle": "2022-06-14T17:12:22.695854Z",
     "shell.execute_reply": "2022-06-14T17:12:22.695280Z",
     "shell.execute_reply.started": "2022-06-14T17:12:22.662105Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import nibabel as nib\n",
    "from nibabel.viewers import OrthoSlicer3D\n",
    "\n",
    "# Collect the image files. glob returns entries in an OS-dependent order,\n",
    "# so sort them to make the file order identical on every run.\n",
    "train_path = sorted(glob.glob('./中国农业大学-作物引导拍照挑战赛公开数据/train/*'))\n",
    "test_path = sorted(glob.glob('./中国农业大学-作物引导拍照挑战赛公开数据/test/*'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "431a9196-30ec-4f86-8da7-cd6e2edc7154",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:13:49.622652Z",
     "iopub.status.busy": "2022-06-14T17:13:49.621998Z",
     "iopub.status.idle": "2022-06-14T17:13:49.630524Z",
     "shell.execute_reply": "2022-06-14T17:13:49.630023Z",
     "shell.execute_reply.started": "2022-06-14T17:13:49.622602Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the label table, then derive image paths and targets from its columns.\n",
    "train_df = pd.read_csv('./中国农业大学-作物引导拍照挑战赛公开数据/train.csv')\n",
    "train_path = train_df['image'].radd('./中国农业大学-作物引导拍照挑战赛公开数据/train/')\n",
    "train_label = train_df.loc[:, 'label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8f5799d1-cab6-4f3e-8ba0-7ffca838e578",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:15:14.625731Z",
     "iopub.status.busy": "2022-06-14T17:15:14.625381Z",
     "iopub.status.idle": "2022-06-14T17:15:14.630547Z",
     "shell.execute_reply": "2022-06-14T17:15:14.629918Z",
     "shell.execute_reply.started": "2022-06-14T17:15:14.625706Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 4,  7,  6,  0, 10,  3,  1,  9,  8,  2,  5])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.unique(train_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "584f2188-cf45-4d72-abee-4dae0d362926",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:20:38.672419Z",
     "iopub.status.busy": "2022-06-14T17:20:38.671820Z",
     "iopub.status.idle": "2022-06-14T17:20:38.680747Z",
     "shell.execute_reply": "2022-06-14T17:20:38.680237Z",
     "shell.execute_reply.started": "2022-06-14T17:20:38.672367Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        img = cv2.imread(self.img_path[index])            \n",
    "        img = img.astype(np.float32)\n",
    "        \n",
    "        img /= 255.0\n",
    "        img -= 1\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(image = img)['image']\n",
    "        img = img.transpose([2,0,1])\n",
    "        return img,torch.from_numpy(np.array(self.img_label[index]))\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:20:39.676349Z",
     "iopub.status.busy": "2022-06-14T17:20:39.675757Z",
     "iopub.status.idle": "2022-06-14T17:20:39.684012Z",
     "shell.execute_reply": "2022-06-14T17:20:39.683504Z",
     "shell.execute_reply.started": "2022-06-14T17:20:39.676298Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(XunFeiNet, self).__init__()\n",
    "                \n",
    "        model = models.resnet18(True)\n",
    "        model.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        model.fc = nn.Linear(512, 11)\n",
    "        self.resnet = model\n",
    "        \n",
    "    def forward(self, img):        \n",
    "        out = self.resnet(img)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "47539bef-14a7-4881-85e9-0882ff295e6a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:17:26.635126Z",
     "iopub.status.busy": "2022-06-14T17:17:26.634936Z",
     "iopub.status.idle": "2022-06-14T17:17:26.715022Z",
     "shell.execute_reply": "2022-06-14T17:17:26.714037Z",
     "shell.execute_reply.started": "2022-06-14T17:17:26.635108Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer):\n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # compute gradient and do SGD step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 20 == 0:\n",
    "            print(loss.item())\n",
    "            \n",
    "        train_loss += loss.item()\n",
    "    \n",
    "    return train_loss/len(train_loader)\n",
    "            \n",
    "def validate(val_loader, model, criterion):\n",
    "    model.eval()\n",
    "    \n",
    "    val_acc = 0.0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "            \n",
    "            val_acc += (output.argmax(1) == target).sum().item()\n",
    "            \n",
    "    return val_acc / len(val_loader.dataset)\n",
    "\n",
    "def predict(test_loader, model, criterion):\n",
    "    model.eval()\n",
    "    val_acc = 0.0\n",
    "    \n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(test_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            test_pred.append(output.data.cpu().numpy())\n",
    "            \n",
    "    return np.vstack(test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:28:50.674806Z",
     "iopub.status.busy": "2022-06-14T17:28:50.674375Z",
     "iopub.status.idle": "2022-06-14T17:28:50.685897Z",
     "shell.execute_reply": "2022-06-14T17:28:50.684711Z",
     "shell.execute_reply.started": "2022-06-14T17:28:50.674757Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "\n",
    "RESIZE_SIDE = 300  # side length every image is resized to before cropping\n",
    "CROP_SIDE = 256    # side length of the crop handed to the network\n",
    "HOLDOUT = 100      # number of trailing training rows kept for validation\n",
    "\n",
    "# Training data: everything except the hold-out rows, with random augmentation.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path.values[:-HOLDOUT], train_label.values[:-HOLDOUT],\n",
    "            A.Compose([\n",
    "            A.Resize(RESIZE_SIDE, RESIZE_SIDE),\n",
    "            A.RandomCrop(CROP_SIDE, CROP_SIDE),\n",
    "            A.RandomBrightnessContrast(p=0.5),\n",
    "        ])\n",
    "    ), batch_size=10, shuffle=True, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "# Validation data: the hold-out rows. A centre crop replaces the random crop\n",
    "# so that the measured accuracy is the same every time it is computed.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path.values[-HOLDOUT:], train_label.values[-HOLDOUT:],\n",
    "            A.Compose([\n",
    "            A.Resize(RESIZE_SIDE, RESIZE_SIDE),\n",
    "            A.CenterCrop(CROP_SIDE, CROP_SIDE),\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "# Test data: labels are unknown, so zeros are used as placeholders. The random\n",
    "# crop is kept on purpose: every pass over this loader sees different crops.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(test_path, [0] * len(test_path),\n",
    "            A.Compose([\n",
    "            A.Resize(RESIZE_SIDE, RESIZE_SIDE),\n",
    "            A.RandomCrop(CROP_SIDE, CROP_SIDE),\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:25:36.627262Z",
     "iopub.status.busy": "2022-06-14T17:25:36.626663Z",
     "iopub.status.idle": "2022-06-14T17:25:36.870131Z",
     "shell.execute_reply": "2022-06-14T17:25:36.869688Z",
     "shell.execute_reply.started": "2022-06-14T17:25:36.627201Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Network and loss live on the GPU; the optimizer is SGD with lr 0.01.\n",
    "model = XunFeiNet().to('cuda')\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = optim.SGD(params=model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:27:01.624238Z",
     "iopub.status.busy": "2022-06-14T17:27:01.623630Z",
     "iopub.status.idle": "2022-06-14T17:28:13.730407Z",
     "shell.execute_reply": "2022-06-14T17:28:13.729738Z",
     "shell.execute_reply.started": "2022-06-14T17:27:01.624177Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.5642861127853394\n",
      "2.1016221046447754\n",
      "1.2280302047729492\n",
      "1.3694757223129272\n",
      "1.7242618799209595\n",
      "1.2328886985778809\n",
      "1.4003621000668098 0.86\n",
      "1.19557785987854\n",
      "0.9005592465400696\n",
      "1.1152288913726807\n",
      "1.7703393697738647\n",
      "1.1053175926208496\n",
      "0.9445725679397583\n",
      "1.2647339665684207 0.88\n",
      "1.7755447626113892\n",
      "1.2936674356460571\n",
      "0.8846734762191772\n",
      "0.9563775062561035\n",
      "0.9804204106330872\n",
      "0.9901953935623169\n",
      "1.1846859583566929 0.94\n"
     ]
    }
   ],
   "source": [
    "# Eight rounds of: one training pass, one validation pass, one report line.\n",
    "for _ in range(8):\n",
    "    train_loss, val_acc = (\n",
    "        train(train_loader, model, criterion, optimizer),\n",
    "        validate(val_loader, model, criterion),\n",
    "    )\n",
    "    print(train_loss, val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "6787d12a-b923-4777-abf5-4d90904551d1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:28:52.914808Z",
     "iopub.status.busy": "2022-06-14T17:28:52.914227Z",
     "iopub.status.idle": "2022-06-14T17:29:55.094253Z",
     "shell.execute_reply": "2022-06-14T17:29:55.093509Z",
     "shell.execute_reply.started": "2022-06-14T17:28:52.914760Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sum the raw outputs of ten passes over the test loader.\n",
    "pred = predict(test_loader, model, criterion)\n",
    "for _ in range(10 - 1):\n",
    "    pred += predict(test_loader, model, criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:30:01.756862Z",
     "iopub.status.busy": "2022-06-14T17:30:01.756263Z",
     "iopub.status.idle": "2022-06-14T17:30:01.766716Z",
     "shell.execute_reply": "2022-06-14T17:30:01.766199Z",
     "shell.execute_reply.started": "2022-06-14T17:30:01.756802Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One row per test file: bare file name plus the arg-max of the summed outputs.\n",
    "# os.path.basename strips the directory whichever separator glob produced.\n",
    "submit = pd.DataFrame(\n",
    "    {\n",
    "        'image': [os.path.basename(x) for x in test_path],\n",
    "        'label': pred.argmax(1)\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:31:41.777676Z",
     "iopub.status.busy": "2022-06-14T17:31:41.777059Z",
     "iopub.status.idle": "2022-06-14T17:31:41.788083Z",
     "shell.execute_reply": "2022-06-14T17:31:41.787381Z",
     "shell.execute_reply.started": "2022-06-14T17:31:41.777626Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Numeric id = the text between the first underscore and the next one (or the\n",
    "# file extension), e.g. 'name_12.jpg' -> 12. It is only used to order the rows.\n",
    "# splitext drops the extension whatever its length.\n",
    "submit['id'] = submit['image'].apply(lambda x: int(os.path.splitext(x)[0].split('_')[1]))\n",
    "submit = submit.sort_values(by='id')\n",
    "submit[['image', 'label']].to_csv('submit2.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "557d4f52-046a-45e2-bf1d-ece5ff5238a7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2022-06-14T17:31:42.671503Z",
     "iopub.status.busy": "2022-06-14T17:31:42.671033Z",
     "iopub.status.idle": "2022-06-14T17:31:42.679484Z",
     "shell.execute_reply": "2022-06-14T17:31:42.678866Z",
     "shell.execute_reply.started": "2022-06-14T17:31:42.671463Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "72       1\n",
       "175      2\n",
       "75       3\n",
       "171      4\n",
       "76       5\n",
       "      ... \n",
       "266    322\n",
       "221    323\n",
       "233    324\n",
       "311    325\n",
       "272    326\n",
       "Name: id, Length: 326, dtype: int64"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submit.loc[:, 'id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129ff32f-8444-4cf2-9997-0101960ea490",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
